# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2023, GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.
"""
Median spectrum post-processor
"""

import logging
import numpy as np
from openquake.baselib import hdf5, sap, parallel, general, performance
from openquake.hazardlib import contexts

U32 = np.uint32
I64 = np.int64
F32 = np.float32


def set_imls(cmaker, uhs):
    """
    Overwrite the intensity levels of a ContextMaker with the values of a
    uniform hazard spectrum (P levels per IMT).

    :param cmaker: a ContextMaker; it is modified in place
    :param uhs: an array of shape (M, P), one row per IMT of the cmaker
    :returns: the same ContextMaker, with new .imtls and .loglevels
    """
    names = [imt.string for imt in cmaker.imts]
    # pair each IMT name with its row of levels
    levels = dict(zip(names, uhs))
    # the log-levels are simply the natural logarithms of the levels
    log_levels = {name: np.log(imls) for name, imls in levels.items()}
    cmaker.imtls = general.DictArray(levels)
    cmaker.loglevels = general.DictArray(log_levels)
    return cmaker


def get_mea_sig_wei(cmaker, ctx, uhs):
    """
    Compute the mean, sigma and tau arrays together with the weights of
    each context, by looping on the chunks produced by cmaker.gen_poes.

    :param cmaker: a ContextMaker instance with G gsims and M imts
    :param ctx: a context array of size C
    :param uhs: an array of shape (M, P)
    :returns:
        mean[G, M, C], sigma[G, M, C], tau[G, M, C], weights[G, M, C, P]
    """
    num_imts = len(cmaker.imts)
    num_poes = len(cmaker.poes)
    shape = (len(cmaker.gsims), num_imts, len(ctx))
    # from now on the cmaker has P levels per IMT, taken from the UHS
    # (NB: set_imls modifies the cmaker in place)
    cmaker = set_imls(cmaker, uhs)
    means = np.empty(shape, np.float32)
    sigmas = np.empty(shape, np.float32)
    taus = np.empty(shape, np.float32)
    weights = np.empty(shape + (num_poes,), np.float32)
    offset = 0
    for poes, mea, sig, tau, ctxt in cmaker.gen_poes(ctx):
        size, _, _ = poes.shape  # (c, L, G) with L = M * P
        # position of the current chunk inside the full context array
        chunk = slice(offset, offset + size)
        offset += size
        means[:, :, chunk] = mea
        sigmas[:, :, chunk] = sig
        taus[:, :, chunk] = tau
        rates = cmaker.get_occ_rates(ctxt)
        for g, w in enumerate(cmaker.wei):
            poes_g = poes[:, :, g].reshape(size, num_imts, num_poes)
            # NB: vectorizing the loops on M, P improves nothing;
            # the important loop is the one on C (up to 6000 elements
            # for Canada) which is vectorized
            for m, _imt in enumerate(cmaker.imtls):
                for p, poe in enumerate(cmaker.poes):
                    # rate * poe / reference poe, scaled by the weight w
                    weights[g, m, chunk, p] = (
                        rates * poes_g[:, m, p] / poe * w)
    return means, sigmas, taus, weights


def tr(arr):
    """
    Bring the last axis of a 3D array to the front.

    :param arr: an array of shape (A, B, C)
    :returns: the transposed array, of shape (C, A, B)
    """
    return np.transpose(arr, (2, 0, 1))


def check_rup_unique(spec_disagg):
    """
    Check that no rupture ID appears more than once across all the
    datasets in `spec_disagg`.

    :param spec_disagg:
        a mapping whose values can be indexed with the key 'rup_id'
    :raises RuntimeError: if at least one rupture ID is duplicated
    """
    all_ids = np.concatenate(
        [dset['rup_id'] for dset in spec_disagg.values()])
    # np.unique drops the duplicates, so a shorter result means repetitions
    if len(np.unique(all_ids)) < len(all_ids):
        raise RuntimeError('The rupture IDs are not unique!')


# NB: we are ignoring IMT-dependent weights
def compute_median_spectrum(
        cmaker, context, uhs, monitor=performance.Monitor()):
    """
    For a given group, computes the median hazard spectrum using a weighted
    mean based on the poes.

    :param cmaker: ContextMaker for a group of sources
    :param context: context array generated by the group of sources
    :param uhs: array of Uniform Hazard Spectra of shape (N, M, P)
    :param monitor: a Monitor instance (not used inside this function)
    :yields:
        for each site affected by the context, a dictionary
        ``{(grp_id, site_id): out}`` where ``out`` has shape (3, M, P) and
        contains the weighted sums of the means, the weighted sums of the
        sigmas and the total weights; if the calculation has a single site
        and a single poe, also a dictionary ``{(grp_id, -1): [arr]}`` with
        the rupture-level records having nonzero weight
    """
    N, M, P = uhs.shape
    # NB: the check must be on the total number of sites N and not on the
    # number of sites in this chunk of contexts, since `main` creates the
    # median_spectrum_disagg datasets only when N == 1 and P == 1
    one_site_poe = N == 1 and P == 1
    for site_id in np.unique(context.sids):
        ctx = context[context.sids == site_id]
        grp_id = ctx[0]['grp_id']
        mea, sig, tau, wei = get_mea_sig_wei(cmaker, ctx, uhs[site_id])
        out = np.empty((3, M, P))  # <mea>, <sig>, tot_w
        # sum over the gsims (g) and the contexts (u)
        out[0] = np.einsum("gmup,gmu->mp", wei, mea)
        out[1] = np.einsum("gmup,gmu->mp", wei, sig)
        out[2] = wei.sum(axis=(0, 2))
        yield {(grp_id, site_id): out}
        if one_site_poe:
            # keep only the contexts with a nonzero total weight
            ok = wei.sum(axis=(0, 1, 3)) > 0
            arr = general.compose_arrays(
                rup_id=ctx.rup_id, mag=ctx.mag, rrup=ctx.rrup,
                occurrence_rate=ctx.occurrence_rate,
                mea=tr(mea), sig=tr(sig), tau=tr(tau), wei=tr(wei[:, :, :, 0]))
            yield {(grp_id, -1): [arr[ok]]}


# NB: we are ignoring IMT-dependent weights
def main(dstore, csm):
    """
    Compute the median hazard spectrum for the reference poe,
    starting from the already stored mean hazard spectrum.

    The results are stored in the dataset "median_spectra" of shape
    (Gr, N, 3, M, P); if there is a single site and a single poe also
    the datasets "median_spectrum_disagg/grp<grp_id>" are stored.

    :param dstore: DataStore of the parent calculation
    :param csm: composite source model (only csm.src_groups is used here)
    :raises ValueError: if the total weights are not close to 1
    :raises RuntimeError: if the stored rupture IDs are not unique
    """
    # consistency checks
    oq = dstore["oqparam"]
    periods = [imt.period for imt in oq.imt_periods()]
    N = len(dstore["sitecol"])
    M = len(oq.imtls)
    assert oq.investigation_time == 1, oq.investigation_time
    assert len(periods) == M, 'IMTs different from PGA, SA'
    assert N <= oq.max_sites_disagg, N
    logging.warning("Median spectrum calculations are still experimental")

    # read the precomputed mean hazard spectrum
    ref_uhs = dstore.sel("hmaps-stats", stat="mean")[:, 0]  # shape NSMP -> NMP
    cmakers = contexts.read_cmakers(dstore).to_array()
    num_gsims = {grp_id: len(cm.gsims) for grp_id, cm in enumerate(cmakers)}
    ctx_by_grp = contexts.read_ctx_by_grp(dstore)
    totsize = sum(len(ctx) * num_gsims[grp_id]
                  for grp_id, ctx in ctx_by_grp.items())
    blocksize = totsize / (oq.concurrent_tasks or 1)
    smap = parallel.Starmap(compute_median_spectrum, h5=dstore)
    for grp_id, ctx in ctx_by_grp.items():
        cmaker = cmakers[grp_id]
        # split the contexts so that each task gets around `blocksize`
        # (context, gsim) pairs
        splits = np.ceil(len(ctx) * num_gsims[grp_id] / blocksize)
        for ctxt in np.array_split(ctx, splits):
            smap.submit((cmaker, ctxt, ref_uhs))
    res = smap.reduce()

    # save the median_spectrum
    Gr = len(csm.src_groups)  # number of groups
    P = len(oq.poes)
    median_spectra = np.zeros((Gr, N, 3, M, P), np.float32)
    tot_w = np.zeros((N, M, P))

    # create median_spectrum_disagg datasets
    if N == 1 and P == 1:
        vec = (F32, (M,))  # one value per IMT
        for grp_id, cm in enumerate(cmakers):
            dtlist = [('rup_id', I64),
                      ('mag', F32),
                      ('rrup', F32),
                      ('occurrence_rate', F32)]
            # fields mea0, mea1, ..., sig0, ..., tau0, ..., wei0, ...
            for prefix in ('mea', 'sig', 'tau', 'wei'):
                dtlist.extend(
                    (f'{prefix}{g}', vec) for g in range(len(cm.gsims)))
            name = f"median_spectrum_disagg/grp{grp_id}"
            logging.info('Creating %s', name)
            dstore.create_dset(name, dtlist)

    for (grp_id, site_id), out in res.items():
        if site_id == -1:  # median_spectrum_disagg
            for arr in out:
                hdf5.extend(dstore[f'median_spectrum_disagg/grp{grp_id}'], arr)
        else:
            median_spectra[grp_id, site_id] = out
            tot_w[site_id] += out[2]
    dstore.create_dset("median_spectra", median_spectra)
    # NB: the `kind` axis has size 3, i.e. the three arrays produced by
    # compute_median_spectrum (<mea>, <sig>, tot_w); there is no tau here
    dstore.set_shape_descr("median_spectra", grp_id=Gr,
                           site_id=N, kind=['mea', 'sig', 'wei'],
                           period=periods, poe=oq.poes)

    # sanity check on the weights
    for p, poe in enumerate(oq.poes):
        maxw = tot_w[:, :, p].max()
        logging.info(f'{poe=} {maxw=}')
        if (np.abs(tot_w[:, :, p] - 1) > .01).any():
            raise ValueError(
                f'The weights sum up to {maxw:.3f} != 1: perhaps the '
                f'hazard curve is not invertible around {poe=}')

    # sanity check on the rup_ids
    if N == 1 and P == 1:
        check_rup_unique(dstore['median_spectrum_disagg'])

if __name__ == "__main__":
    # script entry point: hand `main` over to openquake.baselib.sap.run
    sap.run(main)
